{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IERG 5350 Assignment 2: Model-free Tabular RL\n",
    "\n",
    "*2021-2022 1st term, IERG 5350: Reinforcement Learning. Department of Information Engineering, The Chinese University of Hong Kong. Course Instructor: Professor ZHOU Bolei. Assignment author: PENG Zhenghao.*\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Student Name | Student ID |\n",
    "| :----: | :----: |\n",
    "| TYPE_YOUR_NAME_HERE | TYPE_YOUR_STUDENT_ID_HERE |\n",
    "\n",
    "------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Welecome to the assignment 2 of our RL course. The objective of this assignment is for you to understand the classic methods used in tabular reinforcement learning. \n",
    "\n",
    "You need to go through this self-contained notebook, which contains dozens of **TODOs** in part of the cells and has special `[TODO]` signs. You need to finish all TODOs. \n",
    "\n",
    "Please report any code bugs to us via Github issues.\n",
    "\n",
    "Before you get start, remember to follow the instruction at https://github.com/cuhkrlcourse/ierg5350-assignment-2021 to setup your environment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 1: SARSA\n",
    "\n",
    "(30/100 points)\n",
    "\n",
    "You have noticed that in Assignment 1 - Section 2, we always use the function `trainer._get_transitions()` to get the transition dynamics of the environment, while never call `trainer.env.step()` to really interact with the environment by applying actions. We need to access the internal dynamics of the environment and have somebody implement `_get_transitions` for us. \n",
    "\n",
    "However, this is not feasible in many cases, especially in some real-world tasks like autonomous driving where the transition dynamics is unknown.\n",
    "\n",
    "In this section, we will introduce the model-free family of algorithms that do not require to know the transitions: they only get information from `env.step(action)` and collect information by interacting with the environment.\n",
    "\n",
    "We will continue to use the `TabularRLTrainerAbstract` class to implement algorithms, but remember you should not call `trainer._get_transitions()` anymore.\n",
    "\n",
    "We will use a simpler environment `FrozenLakerNotSlippery-v0` to conduct experiments, which has a `4 X 4` grids and is deterministic. This is because, in a model-free setting, it's extremely hard for a random agent to achieve the goal for the first time. To reduce the time of experiments, we choose to use a simpler environment. In the bonus section, you can try out model-free RL on `FrozenLake8x8-v1` to see what will happen. \n",
    "\n",
    "Now go through each section and start your coding!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Recall the idea of SARSA: it's an on-policy TD control method, which has distinct features compared to policy iteration and value iteration methods in the training process:\n",
    "\n",
    "1. It maintains a state-action pair value function $Q(s_t, a_t) = E \\sum_{i=0} \\gamma^{t+i} r_{t+i}$ to approximate the Q value.\n",
    "2. It does not require to know the internal dynamics of the environment.\n",
    "3. It use an epsilon-greedy strategy to balance exploration and exploitation.\n",
    "\n",
    "In SARSA algorithm, we update the Q value via TD error: \n",
    "\n",
    "$$TD(s_t, a_t) = r(s_t, a_t) + \\gamma Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t),$$\n",
    "\n",
    "wherein we run the policy to get the next action $a_{t+1} = Policy(s_{t+1})$. \n",
    "That's why we call SARSA an on-policy algorithm, since it use the current policy to evaluate Q value.\n",
    "\n",
    "$$Q^{new}(s_t, a_t) = Q(s_t, a_t) + \\alpha TD(s_t, a_t),$$\n",
    "\n",
    "wherein $\\alpha$ is the learning rate, a hyper-parameter provided by the user.\n",
    "\n",
    "Now please go through the codes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "# Import some packages that we need to use\n",
    "from utils import *\n",
    "import gym\n",
    "import numpy as np\n",
    "from collections import deque"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the TODOs and remove `pass`\n",
    "\n",
    "def _render_helper(env):\n",
    "    env.render()\n",
    "    wait(sleep=0.2)\n",
    "\n",
    "\n",
    "def evaluate(policy, num_episodes, seed=0, env_name='FrozenLake8x8-v1', render=False):\n",
    "    \"\"\"[TODO] You need to implement this function by yourself. It\n",
    "    evaluate the given policy and return the mean episode reward.\n",
    "    We use `seed` argument for testing purpose.\n",
    "    You should pass the tests in the next cell.\n",
    "\n",
    "    :param policy: a function whose input is an interger (observation)\n",
    "    :param num_episodes: number of episodes you wish to run\n",
    "    :param seed: an interger, used for testing.\n",
    "    :param env_name: the name of the environment\n",
    "    :param render: a boolean flag. If true, please call _render_helper\n",
    "    function.\n",
    "    :return: the averaged episode reward of the given policy.\n",
    "    \"\"\"\n",
    "\n",
    "    # Create environment (according to env_name, we will use env other than 'FrozenLake8x8-v0')\n",
    "    env = gym.make(env_name)\n",
    "\n",
    "    # Seed the environment\n",
    "    env.seed(seed)\n",
    "\n",
    "    # Build inner loop to run.\n",
    "    # For each episode, do not set the limit.\n",
    "    # Only terminate episode (reset environment) when done = True.\n",
    "    # The episode reward is the sum of all rewards happen within one episode.\n",
    "    # Call the helper function `render(env)` to render\n",
    "    rewards = []\n",
    "    for i in range(num_episodes):\n",
    "        # reset the environment\n",
    "        obs = env.reset()\n",
    "        act = policy(obs)\n",
    "        \n",
    "        ep_reward = 0\n",
    "        while True:\n",
    "            # [TODO] run the environment and terminate it if done, collect the\n",
    "            # reward at each step and sum them to the episode reward.\n",
    "            obs, reward, done, info = env.step(act)\n",
    "            ep_reward += reward\n",
    "            act = policy(obs)\n",
    "            if done:\n",
    "                break\n",
    "        \n",
    "        rewards.append(ep_reward)\n",
    "\n",
    "    return np.mean(rewards)\n",
    "\n",
    "# [TODO] Run next cell to test your implementation!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "class TabularRLTrainerAbstract:\n",
    "    \"\"\"This is the abstract class for tabular RL trainer. We will inherent the specify \n",
    "    algorithm's trainer from this abstract class, so that we can reuse the codes like\n",
    "    getting the dynamic of the environment (self._get_transitions()) or rendering the\n",
    "    learned policy (self.render()).\"\"\"\n",
    "    \n",
    "    def __init__(self, env_name='FrozenLake8x8-v1', model_based=True):\n",
    "        self.env_name = env_name\n",
    "        self.env = gym.make(self.env_name)\n",
    "        self.action_dim = self.env.action_space.n\n",
    "        self.obs_dim = self.env.observation_space.n\n",
    "        \n",
    "        self.model_based = model_based\n",
    "\n",
    "    def _get_transitions(self, state, act):\n",
    "        \"\"\"Query the environment to get the transition probability,\n",
    "        reward, the next state, and done given a pair of state and action.\n",
    "        We implement this function for you. But you need to know the \n",
    "        return format of this function.\n",
    "        \"\"\"\n",
    "        self._check_env_name()\n",
    "        assert self.model_based, \"You should not use _get_transitions in \" \\\n",
    "            \"model-free algorithm!\"\n",
    "        \n",
    "        # call the internal attribute of the environments.\n",
    "        # `transitions` is a list contain all possible next states and the \n",
    "        # probability, reward, and termination indicater corresponding to it\n",
    "        transitions = self.env.env.P[state][act]\n",
    "\n",
    "        # Given a certain state and action pair, it is possible\n",
    "        # to find there exist multiple transitions, since the \n",
    "        # environment is not deterministic.\n",
    "        # You need to know the return format of this function: a list of dicts\n",
    "        ret = []\n",
    "        for prob, next_state, reward, done in transitions:\n",
    "            ret.append({\n",
    "                \"prob\": prob,\n",
    "                \"next_state\": next_state,\n",
    "                \"reward\": reward,\n",
    "                \"done\": done\n",
    "            })\n",
    "        return ret\n",
    "    \n",
    "    def _check_env_name(self):\n",
    "        assert self.env_name.startswith('FrozenLake')\n",
    "\n",
    "    def print_table(self):\n",
    "        \"\"\"print beautiful table, only work for FrozenLake8X8-v1 env. We \n",
    "        write this function for you.\"\"\"\n",
    "        self._check_env_name()\n",
    "        print_table(self.table)\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"Conduct one iteration of learning.\"\"\"\n",
    "        raise NotImplementedError(\"You need to override the \"\n",
    "                                  \"Trainer.train() function.\")\n",
    "\n",
    "    def evaluate(self):\n",
    "        \"\"\"Use the function you write to evaluate current policy.\n",
    "        Return the mean episode reward of 1000 episodes when seed=0.\"\"\"\n",
    "        result = evaluate(self.policy, 1000, env_name=self.env_name)\n",
    "        return result\n",
    "\n",
    "    def render(self):\n",
    "        \"\"\"Reuse your evaluate function, render current policy \n",
    "        for one episode when seed=0\"\"\"\n",
    "        evaluate(self.policy, 1, render=True, env_name=self.env_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the TODOs and remove `pass`\n",
    "\n",
    "class SARSATrainer(TabularRLTrainerAbstract):\n",
    "    def __init__(self,\n",
    "                 gamma=1.0,\n",
    "                 eps=0.1,\n",
    "                 learning_rate=1.0,\n",
    "                 max_episode_length=100,\n",
    "                 env_name='FrozenLake8x8-v1'\n",
    "                 ):\n",
    "        super(SARSATrainer, self).__init__(env_name, model_based=False)\n",
    "\n",
    "        # discount factor\n",
    "        self.gamma = gamma\n",
    "\n",
    "        # epsilon-greedy exploration policy parameter\n",
    "        self.eps = eps\n",
    "\n",
    "        # maximum steps in single episode\n",
    "        self.max_episode_length = max_episode_length\n",
    "\n",
    "        # the learning rate\n",
    "        self.learning_rate = learning_rate\n",
    "\n",
    "        # build the Q table\n",
    "        # [TODO] uncomment the next line, pay attention to the shape\n",
    "        self.table = np.zeros((self.obs_dim, self.action_dim))\n",
    "\n",
    "    def policy(self, obs):\n",
    "        \"\"\"Implement epsilon-greedy policy\n",
    "\n",
    "        It is a function that take an integer (state / observation)\n",
    "        as input and return an interger (action).\n",
    "        \"\"\"\n",
    "\n",
    "        # [TODO] You need to implement the epsilon-greedy policy here.\n",
    "        # hint: We have self.eps probability to choose a uniformly random\n",
    "        #  action in range [0, 1, .., self.action_dim - 1], \n",
    "        #  otherwise choose action that maximize the Q value\n",
    "        policy_type = np.random.choice([0, 1], p=[self.eps, 1-self.eps])\n",
    "        if policy_type:\n",
    "            return np.argmax(self.table[obs])\n",
    "        else:\n",
    "            return np.random.choice(list(range(self.action_dim)))\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"Conduct one iteration of learning.\"\"\"\n",
    "        # [TODO] Q table may be need to be reset to zeros.\n",
    "        # if you think it should, than do it. If not, then move on.\n",
    "        pass\n",
    "        # No, we should do nothing.\n",
    "\n",
    "        obs = self.env.reset()\n",
    "        act = self.policy(obs)\n",
    "        for t in range(self.max_episode_length):\n",
    "\n",
    "\n",
    "            next_obs, reward, done, _ = self.env.step(act)\n",
    "            next_act = self.policy(next_obs)\n",
    "\n",
    "            # [TODO] compute the TD error, based on the next observation and\n",
    "            #  action.\n",
    "            td_error = reward + self.gamma * self.table[next_obs][next_act] - self.table[obs][act]\n",
    "\n",
    "            # [TODO] compute the new Q value\n",
    "            # hint: use TD error, self.learning_rate and old Q value\n",
    "            new_value = self.table[obs][act] + self.learning_rate * td_error\n",
    "\n",
    "            self.table[obs][act] = new_value\n",
    "\n",
    "            # [TODO] Implement (1) break if done. (2) update obs for next \n",
    "            #  self.policy(obs) call\n",
    "            obs = next_obs\n",
    "            act = next_act\n",
    "            if done:\n",
    "                break\n",
    "\n",
    "# [TODO] run the next cell to check your code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you have finished the SARSA trainer. To make sure your implementation of epsilon-greedy strategy is correct, please run the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Policy Test passed!\n"
     ]
    }
   ],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "# set eps = 0 to disable exploration.\n",
    "test_trainer = SARSATrainer(eps=0.0)\n",
    "test_trainer.table.fill(0)\n",
    "\n",
    "# set the Q value of (obs 0, act 3) to 100, so that it should be taken by \n",
    "# policy.\n",
    "test_obs = 0\n",
    "test_act = test_trainer.action_dim - 1\n",
    "test_trainer.table[test_obs][test_act] = 100\n",
    "\n",
    "# assertion\n",
    "assert test_trainer.policy(test_obs) == test_act, \\\n",
    "    \"Your action is wrong! Should be {} but get {}.\".format(\n",
    "        test_act, test_trainer.policy(test_obs))\n",
    "\n",
    "# delete trainer\n",
    "del test_trainer\n",
    "\n",
    "# set eps = 0 to disable exploitation.\n",
    "test_trainer = SARSATrainer(eps=1.0)\n",
    "test_trainer.table.fill(0)\n",
    "\n",
    "act_set = set()\n",
    "for i in range(100):\n",
    "    act_set.add(test_trainer.policy(0))\n",
    "\n",
    "# assertion\n",
    "assert len(act_set) > 1, (\"You sure your uniformaly action selection mechanism\"\n",
    "                          \" is working? You only take action {} when \"\n",
    "                          \"observation is 0, though we run trainer.policy() \"\n",
    "                          \"for 100 times.\".format(act_set))\n",
    "# delete trainer\n",
    "del test_trainer\n",
    "\n",
    "print(\"Policy Test passed!\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now run the next cells to see the result. \n",
    "\n",
    "\n",
    "Note that we use the non-slippery version of a small frozen lake environment `FrozenLakeNotSlipppery-v0` (this is not a ready Gym environment, see `utils.py` for details). This is because, in the model-free setting, it's extremely hard to access the goal for the first time (you should already know that if you watch the agent randomly acting in Assignment 1 - Section 1)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve TODO\n",
    "\n",
    "# Managing configurations of your experiments is important for your research.\n",
    "default_sarsa_config = dict(\n",
    "    max_iteration=20000,\n",
    "    max_episode_length=200,\n",
    "    learning_rate=0.01,\n",
    "    evaluate_interval=1000,\n",
    "    gamma=0.8,\n",
    "    eps=0.3,\n",
    "    env_name='FrozenLakeNotSlippery-v0'\n",
    ")\n",
    "\n",
    "\n",
    "def sarsa(train_config=None):\n",
    "    config = default_sarsa_config.copy()\n",
    "    if train_config is not None:\n",
    "        config.update(train_config)\n",
    "\n",
    "    trainer = SARSATrainer(\n",
    "        gamma=config['gamma'],\n",
    "        eps=config['eps'],\n",
    "        learning_rate=config['learning_rate'],\n",
    "        max_episode_length=config['max_episode_length'],\n",
    "        env_name=config['env_name']\n",
    "    )\n",
    "\n",
    "    for i in range(config['max_iteration']):\n",
    "        # train the agent\n",
    "        trainer.train()  # [TODO] please uncomment this line\n",
    "\n",
    "        # evaluate the result\n",
    "        if i % config['evaluate_interval'] == 0:\n",
    "            print(\n",
    "                \"[INFO]\\tIn {} iteration, current mean episode reward is {}.\"\n",
    "                \"\".format(i, trainer.evaluate()))\n",
    "\n",
    "    if trainer.evaluate() < 0.6:\n",
    "        print(\"We expect to get the mean episode reward greater than 0.6. \" \\\n",
    "        \"But you get: {}. Please check your codes.\".format(trainer.evaluate()))\n",
    "\n",
    "    return trainer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO]\tIn 0 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 1000 iteration, current mean episode reward is 0.004.\n",
      "[INFO]\tIn 2000 iteration, current mean episode reward is 0.661.\n",
      "[INFO]\tIn 3000 iteration, current mean episode reward is 0.66.\n",
      "[INFO]\tIn 4000 iteration, current mean episode reward is 0.655.\n",
      "[INFO]\tIn 5000 iteration, current mean episode reward is 0.674.\n",
      "[INFO]\tIn 6000 iteration, current mean episode reward is 0.678.\n",
      "[INFO]\tIn 7000 iteration, current mean episode reward is 0.659.\n",
      "[INFO]\tIn 8000 iteration, current mean episode reward is 0.677.\n",
      "[INFO]\tIn 9000 iteration, current mean episode reward is 0.649.\n",
      "[INFO]\tIn 10000 iteration, current mean episode reward is 0.641.\n",
      "[INFO]\tIn 11000 iteration, current mean episode reward is 0.642.\n",
      "[INFO]\tIn 12000 iteration, current mean episode reward is 0.677.\n",
      "[INFO]\tIn 13000 iteration, current mean episode reward is 0.675.\n",
      "[INFO]\tIn 14000 iteration, current mean episode reward is 0.678.\n",
      "[INFO]\tIn 15000 iteration, current mean episode reward is 0.653.\n",
      "[INFO]\tIn 16000 iteration, current mean episode reward is 0.642.\n",
      "[INFO]\tIn 17000 iteration, current mean episode reward is 0.671.\n",
      "[INFO]\tIn 18000 iteration, current mean episode reward is 0.648.\n",
      "[INFO]\tIn 19000 iteration, current mean episode reward is 0.684.\n"
     ]
    }
   ],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "sarsa_trainer = sarsa()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== The state value for action 0 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.124|0.123|0.062|0.003|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.167|0.000|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.229|0.246|0.293|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.000|0.505|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 1 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.167|0.000|0.008|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.239|0.000|0.326|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.000|0.478|0.716|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.511|0.713|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 2 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.084|0.023|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.000|0.000|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.338|0.449|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.711|1.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 3 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.125|0.060|0.004|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.125|0.000|0.005|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.165|0.000|0.116|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.352|0.480|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "sarsa_trainer.print_table()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "sarsa_trainer.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you have finished the SARSA algorithm."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 2: Q-Learning\n",
    "(30/100 points)\n",
    "\n",
    "Q-learning is an off-policy algorithm who differs from SARSA in the computing of TD error. Instead of running policy to get `next_act` $a'$ and get the TD error by:\n",
    "\n",
    "$r + \\gamma Q(s', a') - Q(s, a), a' \\sim \\pi(\\cdot|s')$, \n",
    "\n",
    "in Q-learning we compute the TD error via:\n",
    "\n",
    "$r + \\gamma \\max_{a'} Q(s', a') - Q(s, a)$. \n",
    "\n",
    "The reason we call it \"off-policy\" is that the policy involves the computing of next-Q value is not the \"behavior policy\", instead, it is a \"optimal policy\" that always takes the best action given current Q values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the TODOs and remove `pass`\n",
    "\n",
    "class QLearningTrainer(TabularRLTrainerAbstract):\n",
    "    def __init__(self,\n",
    "                 gamma=1.0,\n",
    "                 eps=0.1,\n",
    "                 learning_rate=1.0,\n",
    "                 max_episode_length=100,\n",
    "                 env_name='FrozenLake8x8-v1'\n",
    "                 ):\n",
    "        super(QLearningTrainer, self).__init__(env_name, model_based=False)\n",
    "        self.gamma = gamma\n",
    "        self.eps = eps\n",
    "        self.max_episode_length = max_episode_length\n",
    "        self.learning_rate = learning_rate\n",
    "\n",
    "        # build the Q table\n",
    "        self.table = np.zeros((self.obs_dim, self.action_dim))\n",
    "\n",
    "    def policy(self, obs):\n",
    "        \"\"\"Implement epsilon-greedy policy\n",
    "\n",
    "        It is a function that take an integer (state / observation)\n",
    "        as input and return an interger (action).\n",
    "        \"\"\"\n",
    "\n",
    "        # [TODO] You need to implement the epsilon-greedy policy here.\n",
    "        # hint: Just copy your codes in SARSATrainer.policy()\n",
    "        policy_type = np.random.choice([0, 1], p=[self.eps, 1-self.eps])\n",
    "        if policy_type:\n",
    "            return np.argmax(self.table[obs])\n",
    "        else:\n",
    "            return np.random.choice(list(range(self.action_dim)))\n",
    "\n",
    "    def train(self):\n",
    "        \"\"\"Conduct one iteration of learning.\"\"\"\n",
    "        # [TODO] Q table may be need to be reset to zeros.\n",
    "        # if you think it should, than do it. If not, then move on.\n",
    "        pass\n",
    "        # No, we should do nothing.\n",
    "\n",
    "        obs = self.env.reset()\n",
    "        for t in range(self.max_episode_length):\n",
    "            act = self.policy(obs)\n",
    "\n",
    "            next_obs, reward, done, _ = self.env.step(act)\n",
    "\n",
    "            # [TODO] compute the TD error, based on the next observation\n",
    "            # hint: we do not need next_act anymore.\n",
    "            td_error = reward + self.gamma * np.max(self.table[next_obs]) - self.table[obs][act]\n",
    "\n",
    "            # [TODO] compute the new Q value\n",
    "            # hint: use TD error, self.learning_rate and old Q value\n",
    "            new_value = self.table[obs][act] +self.learning_rate * td_error\n",
    "\n",
    "            self.table[obs][act] = new_value\n",
    "            obs = next_obs\n",
    "            if done:\n",
    "                break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the TODO\n",
    "\n",
    "# Managing configurations of your experiments is important for your research.\n",
    "default_q_learning_config = dict(\n",
    "    max_iteration=20000,\n",
    "    max_episode_length=200,\n",
    "    learning_rate=0.01,\n",
    "    evaluate_interval=1000,\n",
    "    gamma=0.8,\n",
    "    eps=0.3,\n",
    "    env_name='FrozenLakeNotSlippery-v0'\n",
    ")\n",
    "\n",
    "\n",
    "def q_learning(train_config=None):\n",
    "    config = default_q_learning_config.copy()\n",
    "    if train_config is not None:\n",
    "        config.update(train_config)\n",
    "\n",
    "    trainer = QLearningTrainer(\n",
    "        gamma=config['gamma'],\n",
    "        eps=config['eps'],\n",
    "        learning_rate=config['learning_rate'],\n",
    "        max_episode_length=config['max_episode_length'],\n",
    "        env_name=config['env_name']\n",
    "    )\n",
    "\n",
    "    for i in range(config['max_iteration']):\n",
    "        # train the agent\n",
    "        trainer.train()  # [TODO] please uncomment this line\n",
    "\n",
    "        # evaluate the result\n",
    "        if i % config['evaluate_interval'] == 0:\n",
    "            print(\n",
    "                \"[INFO]\\tIn {} iteration, current mean episode reward is {}.\"\n",
    "                \"\".format(i, trainer.evaluate()))\n",
    "\n",
    "    if trainer.evaluate() < 0.6:\n",
    "        print(\"We expect to get the mean episode reward greater than 0.6. \" \\\n",
    "        \"But you get: {}. Please check your codes.\".format(trainer.evaluate()))\n",
    "\n",
    "    return trainer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO]\tIn 0 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 1000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 2000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 3000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 4000 iteration, current mean episode reward is 0.623.\n",
      "[INFO]\tIn 5000 iteration, current mean episode reward is 0.632.\n",
      "[INFO]\tIn 6000 iteration, current mean episode reward is 0.64.\n",
      "[INFO]\tIn 7000 iteration, current mean episode reward is 0.623.\n",
      "[INFO]\tIn 8000 iteration, current mean episode reward is 0.622.\n",
      "[INFO]\tIn 9000 iteration, current mean episode reward is 0.666.\n",
      "[INFO]\tIn 10000 iteration, current mean episode reward is 0.625.\n",
      "[INFO]\tIn 11000 iteration, current mean episode reward is 0.612.\n",
      "[INFO]\tIn 12000 iteration, current mean episode reward is 0.622.\n",
      "[INFO]\tIn 13000 iteration, current mean episode reward is 0.657.\n",
      "[INFO]\tIn 14000 iteration, current mean episode reward is 0.662.\n",
      "[INFO]\tIn 15000 iteration, current mean episode reward is 0.627.\n",
      "[INFO]\tIn 16000 iteration, current mean episode reward is 0.676.\n",
      "[INFO]\tIn 17000 iteration, current mean episode reward is 0.646.\n",
      "[INFO]\tIn 18000 iteration, current mean episode reward is 0.636.\n",
      "[INFO]\tIn 19000 iteration, current mean episode reward is 0.693.\n"
     ]
    }
   ],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "q_learning_trainer = q_learning()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== The state value for action 0 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.262|0.262|0.178|0.019|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.328|0.000|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.410|0.410|0.512|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.000|0.640|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 1 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.328|0.000|0.072|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.410|0.000|0.640|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.000|0.640|0.800|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.553|0.800|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 2 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.210|0.077|0.001|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.000|0.000|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.512|0.640|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.800|1.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 3 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.262|0.153|0.010|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.262|0.000|0.053|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.328|0.000|0.512|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.423|0.640|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "q_learning_trainer.print_table()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "q_learning_trainer.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you have finished Q-Learning algorithm."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Section 3: Monte Carlo Control\n",
    "(40/100 points)\n",
    "\n",
    "In sections 1 and 2, we implement the on-policy and off-policy versions of the TD Learning algorithms. In this section, we will play with another branch of the model-free algorithm: Monte Carlo Control. You can refer to the 5.3 Monte Carlo Control section of the textbook \"Reinforcement Learning: An Introduction\" to learn the details of MC control.\n",
    "\n",
    "\n",
    "The basic idea of MC control is to compute the Q value (state-action value) directly from an episode, without using TD to fit the Q function. \n",
    "\n",
    "\n",
    "Concretely, we maintain a batch of lists (the total number of lists is `obs_dim * action_dim`), each elememnt of the batch is a list correspondent to a state-action pair. The list is used to store the previously happenning \"return\" of each state action pair. The \"return\" here is the discounted accumulative reward of the trajectories starting from the state-action pair: $Return(s_t, a_t) = \\sum_{i=0} \\gamma^{t+i} r_{t+i}$.\n",
    "\n",
    "For example, the batch might looks like:\n",
    "\n",
    "```\n",
    "[(state=\"in left upper corner\", action=\"turn right\") = [10.0, 20.0, 30.0],\n",
    " (state=..., action=...) = [previously recorded return ...],\n",
    "...\n",
    "]\n",
    "```\n",
    "\n",
    "\n",
    "We will use a dict `self.returns` to store all lists. The keys of the dict are tuples `(obs, act)` and the value of the dict `self.returns[(obs, act)]` is the list to store all returns of the trajectories that starts from `(obs, act)`. \n",
    "\n",
    "The key point of MC Control method is that we take the mean of this list (the mean of all previous returns) as the Q value of the corresponding state-action pair.\n",
    "In short, MC Control method uses a new way to estimate the values of state-action pairs.\n"
   ]
  },
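  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bookkeeping described above can be sketched in plain Python. This is a minimal illustration, not the assignment's reference solution: `discounted_returns`, `returns_per_pair`, `q_estimate`, and the three-step episode are hypothetical names and data chosen for the example, and `returns_per_pair` only mirrors the role of `self.returns`.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def discounted_returns(rewards, gamma):\n",
    "    # Return [G_0, ..., G_{T-1}] using G_t = r_t + gamma * G_{t+1}.\n",
    "    returns = [0.0] * len(rewards)\n",
    "    value = 0.0\n",
    "    for t in reversed(range(len(rewards))):\n",
    "        value = rewards[t] + gamma * value\n",
    "        returns[t] = value\n",
    "    return returns\n",
    "\n",
    "\n",
    "# Hypothetical episode: three (obs, act) pairs, reward 1 only at the end.\n",
    "episode = [(0, 1), (4, 1), (8, 2)]\n",
    "rewards = [0.0, 0.0, 1.0]\n",
    "\n",
    "# One list of observed returns per state-action pair.\n",
    "returns_per_pair = defaultdict(list)\n",
    "for pair, g in zip(episode, discounted_returns(rewards, gamma=0.8)):\n",
    "    returns_per_pair[pair].append(g)\n",
    "\n",
    "# The Q estimate of a pair is the mean of its recorded returns.\n",
    "q_estimate = {pair: sum(gs) / len(gs) for pair, gs in returns_per_pair.items()}\n",
    "```\n",
    "\n",
    "Computing the returns backward keeps the cost linear in the episode length, since each return reuses the one that follows it."
   ]
  },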
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the TODOs and remove `pass`\n",
    "\n",
    "class MCControlTrainer(TabularRLTrainerAbstract):\n",
    "    def __init__(self,\n",
    "                 gamma=1.0,\n",
    "                 eps=0.3,\n",
    "                 max_episode_length=100,\n",
    "                 env_name='FrozenLake8x8-v1'\n",
    "                 ):\n",
    "        super(MCControlTrainer, self).__init__(env_name, model_based=False)\n",
    "        self.gamma = gamma\n",
    "        self.eps = eps\n",
    "        self.max_episode_length = max_episode_length\n",
    "\n",
    "        # build the dict of lists\n",
    "        self.returns = {}\n",
    "        for obs in range(self.obs_dim):\n",
    "            for act in range(self.action_dim):\n",
    "                self.returns[(obs, act)] = []\n",
    "\n",
    "        # build the Q table\n",
    "        self.table = np.zeros((self.obs_dim, self.action_dim))\n",
    "\n",
    "    def policy(self, obs):\n",
    "        \"\"\"Implement epsilon-greedy policy\n",
    "\n",
    "        It is a function that take an integer (state / observation)\n",
    "        as input and return an interger (action).\n",
    "        \"\"\"\n",
    "\n",
    "        # [TODO] You need to implement the epsilon-greedy policy here.\n",
    "        # hint: Just copy your codes in SARSATrainer.policy()\n",
    "        policy_type = np.random.choice([0, 1], p=[self.eps, 1-self.eps])\n",
    "        if policy_type:\n",
    "            return np.argmax(self.table[obs])\n",
    "        else:\n",
    "            return np.random.choice(list(range(self.action_dim)))\n",
    "            \n",
    "    def train(self):\n",
    "        \"\"\"Conduct one iteration of learning.\"\"\"\n",
    "        observations = []\n",
    "        actions = []\n",
    "        rewards = []\n",
    "\n",
    "        # [TODO] rollout for one episode, store data in three lists create \n",
    "        #  above.\n",
    "        # hint: we do not need to store next observation.\n",
    "        obs = self.env.reset()\n",
    "        done = False\n",
    "        while not done:\n",
    "            act = self.policy(obs)\n",
    "            next_obs, reward, done, _ = self.env.step(act)\n",
    "            observations.append(obs)\n",
    "            actions.append(act)\n",
    "            rewards.append(reward)\n",
    "            obs = next_obs\n",
    "\n",
    "        assert len(actions) == len(observations)\n",
    "        assert len(actions) == len(rewards)\n",
    "\n",
    "        occured_state_action_pair = set()\n",
    "        length = len(actions)\n",
    "        value = 0\n",
    "        for i in reversed(range(length)):\n",
    "            # if length = 10, then i = 9, 8, ..., 0\n",
    "\n",
    "            obs = observations[i]\n",
    "            act = actions[i]\n",
    "            reward = rewards[i]\n",
    "\n",
    "            # [TODO] compute the value reversely\n",
    "            # hint: value(t) = gamma * value(t+1) + r(t)\n",
    "            value = reward\n",
    "            for j in range(i+1, length):\n",
    "                value += self.gamma ** (j-i) * rewards[j]\n",
    "\n",
    "            if (obs, act) not in occured_state_action_pair:\n",
    "                occured_state_action_pair.add((obs, act))\n",
    "\n",
    "                # [TODO] append current return (value) to dict\n",
    "                # hint: `value` represents the future return due to \n",
    "                #  current (obs, act), so we need to store this value\n",
    "                #  in trainer.returns\n",
    "                self.returns[(obs, act)].append(value)\n",
    "\n",
    "                # [TODO] compute the Q value from self.returns and write it \n",
    "                #  into self.table\n",
    "                self.table[obs][act] = np.average(self.returns[(obs, act)])\n",
    "\n",
    "                # we don't need to update the policy since it is \n",
    "                # automatically adjusted with self.table\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "# Managing configurations of your experiments is important for your research.\n",
    "default_mc_control_config = dict(\n",
    "    max_iteration=20000,\n",
    "    max_episode_length=200,\n",
    "    evaluate_interval=1000,\n",
    "    gamma=0.8,\n",
    "    eps=0.3,\n",
    "    env_name='FrozenLakeNotSlippery-v0'\n",
    ")\n",
    "\n",
    "\n",
    "def mc_control(train_config=None):\n",
    "    config = default_mc_control_config.copy()\n",
    "    if train_config is not None:\n",
    "        config.update(train_config)\n",
    "\n",
    "    trainer = MCControlTrainer(\n",
    "        gamma=config['gamma'],\n",
    "        eps=config['eps'],\n",
    "        max_episode_length=config['max_episode_length'],\n",
    "        env_name=config['env_name']\n",
    "    )\n",
    "\n",
    "    for i in range(config['max_iteration']):\n",
    "        # train the agent\n",
    "        trainer.train()\n",
    "\n",
    "        # evaluate the result\n",
    "        if i % config['evaluate_interval'] == 0:\n",
    "            print(\n",
    "                \"[INFO]\\tIn {} iteration, current mean episode reward is {}.\"\n",
    "                \"\".format(i, trainer.evaluate()))\n",
    "\n",
    "    if trainer.evaluate() < 0.6:\n",
    "        print(\"We expect to get the mean episode reward greater than 0.6. \" \\\n",
    "        \"But you get: {}. Please check your codes.\".format(trainer.evaluate()))\n",
    "\n",
    "    return trainer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO]\tIn 0 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 1000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 2000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 3000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 4000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 5000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 6000 iteration, current mean episode reward is 0.66.\n",
      "[INFO]\tIn 7000 iteration, current mean episode reward is 0.669.\n",
      "[INFO]\tIn 8000 iteration, current mean episode reward is 0.664.\n",
      "[INFO]\tIn 9000 iteration, current mean episode reward is 0.668.\n",
      "[INFO]\tIn 10000 iteration, current mean episode reward is 0.675.\n",
      "[INFO]\tIn 11000 iteration, current mean episode reward is 0.683.\n",
      "[INFO]\tIn 12000 iteration, current mean episode reward is 0.653.\n",
      "[INFO]\tIn 13000 iteration, current mean episode reward is 0.693.\n",
      "[INFO]\tIn 14000 iteration, current mean episode reward is 0.64.\n",
      "[INFO]\tIn 15000 iteration, current mean episode reward is 0.676.\n",
      "[INFO]\tIn 16000 iteration, current mean episode reward is 0.666.\n",
      "[INFO]\tIn 17000 iteration, current mean episode reward is 0.682.\n",
      "[INFO]\tIn 18000 iteration, current mean episode reward is 0.679.\n",
      "[INFO]\tIn 19000 iteration, current mean episode reward is 0.652.\n",
      "[INFO]\tIn 0 iteration, current mean episode reward is 0.001.\n",
      "[INFO]\tIn 1000 iteration, current mean episode reward is 0.626.\n",
      "[INFO]\tIn 2000 iteration, current mean episode reward is 0.68.\n",
      "[INFO]\tIn 3000 iteration, current mean episode reward is 0.663.\n",
      "[INFO]\tIn 4000 iteration, current mean episode reward is 0.655.\n",
      "[INFO]\tIn 5000 iteration, current mean episode reward is 0.683.\n",
      "[INFO]\tIn 6000 iteration, current mean episode reward is 0.639.\n",
      "[INFO]\tIn 7000 iteration, current mean episode reward is 0.635.\n",
      "[INFO]\tIn 8000 iteration, current mean episode reward is 0.658.\n",
      "[INFO]\tIn 9000 iteration, current mean episode reward is 0.662.\n",
      "[INFO]\tIn 10000 iteration, current mean episode reward is 0.654.\n",
      "[INFO]\tIn 11000 iteration, current mean episode reward is 0.675.\n",
      "[INFO]\tIn 12000 iteration, current mean episode reward is 0.629.\n",
      "[INFO]\tIn 13000 iteration, current mean episode reward is 0.656.\n",
      "[INFO]\tIn 14000 iteration, current mean episode reward is 0.673.\n",
      "[INFO]\tIn 15000 iteration, current mean episode reward is 0.653.\n",
      "[INFO]\tIn 16000 iteration, current mean episode reward is 0.644.\n",
      "[INFO]\tIn 17000 iteration, current mean episode reward is 0.677.\n",
      "[INFO]\tIn 18000 iteration, current mean episode reward is 0.659.\n",
      "[INFO]\tIn 19000 iteration, current mean episode reward is 0.669.\n"
     ]
    }
   ],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "mc_control_trainer = mc_control()\n",
    "\n",
    "sarsa_trainer = sarsa()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "=== The state value for action 0 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.027|0.006|0.028|0.147|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.037|0.000|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.082|0.141|0.261|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.000|0.504|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 1 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.127|0.000|0.318|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.210|0.000|0.506|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.000|0.514|0.746|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.503|0.733|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 2 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.047|0.163|0.108|0.119|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.000|0.000|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.337|0.447|0.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.749|1.000|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n",
      "\n",
      "=== The state value for action 3 ===\n",
      "+-----+-----+-----+-----+-----+\n",
      "|     |   0 |   1 |   2 |   3 |\n",
      "|-----+-----+-----+-----+-----+\n",
      "| 0   |0.037|0.025|0.158|0.102|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 1   |0.051|0.000|0.249|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 2   |0.083|0.000|0.320|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "| 3   |0.000|0.366|0.501|0.000|\n",
      "|     |     |     |     |     |\n",
      "+-----+-----+-----+-----+-----+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "mc_control_trainer.print_table()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run this cell without modification\n",
    "\n",
    "mc_control_trainer.render()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Secion 4 Bonus (optional): Tune and train FrozenLake8x8-v1 with Model-free algorithms\n",
    "\n",
    "You have noticed that we use a simpler environment `FrozenLakeNotSlippery-v0` which has only 16 states and is not stochastic. Can you try to train Model-free families of algorithm using the `FrozenLake8x8-v1` environment? Tune the hyperparameters and compare the results between different algorithms.\n",
    "\n",
    "Hint: It's not easy to train model-free algorithm in `FrozenLake8x8-v1`. Failure is excepted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[INFO]\tIn 0 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 1000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 2000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 3000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 4000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 5000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 6000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 7000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 8000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 9000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 10000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 11000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 12000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 13000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 14000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 15000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 16000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 17000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 18000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 19000 iteration, current mean episode reward is 0.0.\n",
      "We expect to get the mean episode reward greater than 0.6. But you get: 0.0. Please check your codes.\n",
      "[INFO]\tIn 0 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 1000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 2000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 3000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 4000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 5000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 6000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 7000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 8000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 9000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 10000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 11000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 12000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 13000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 14000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 15000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 16000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 17000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 18000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 19000 iteration, current mean episode reward is 0.0.\n",
      "We expect to get the mean episode reward greater than 0.6. But you get: 0.0. Please check your codes.\n",
      "[INFO]\tIn 0 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 1000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 2000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 3000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 4000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 5000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 6000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 7000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 8000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 9000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 10000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 11000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 12000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 13000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 14000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 15000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 16000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 17000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 18000 iteration, current mean episode reward is 0.0.\n",
      "[INFO]\tIn 19000 iteration, current mean episode reward is 0.0.\n",
      "We expect to get the mean episode reward greater than 0.6. But you get: 0.0. Please check your codes.\n"
     ]
    }
   ],
   "source": [
    "# It's ok to leave this cell commented.\n",
    "\n",
    "new_config = dict(\n",
    "    max_iteration=20000,\n",
    "    max_episode_length=200,\n",
    "    learning_rate=0.3,\n",
    "    evaluate_interval=1000,\n",
    "    gamma=0.8,\n",
    "    eps=0.5,\n",
    "    env_name=\"FrozenLake8x8-v1\"\n",
    ")\n",
    "\n",
    "new_mc_control_trainer = mc_control(new_config)\n",
    "new_q_learning_trainer = q_learning(new_config)\n",
    "new_sarsa_trainer = sarsa(new_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now you have implement the MC Control algorithm. You have finished this section. If you want to do more investigation like comparing the policy provided by SARSA, Q-Learning and MC Control, then you can do it in the next cells. It's OK to leave it blank."
   ]
  },
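  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible way to compare policies is to measure how often the greedy actions of two Q tables agree. The sketch below is only an illustration: `greedy_actions`, `policy_agreement`, and the two small tables are hypothetical and not part of the assignment code.\n",
    "\n",
    "```python\n",
    "def greedy_actions(table):\n",
    "    # One greedy action per state; ties go to the lowest action index.\n",
    "    return [max(range(len(row)), key=lambda a: row[a]) for row in table]\n",
    "\n",
    "\n",
    "def policy_agreement(table_a, table_b):\n",
    "    # Fraction of states where both greedy policies pick the same action.\n",
    "    acts_a = greedy_actions(table_a)\n",
    "    acts_b = greedy_actions(table_b)\n",
    "    same = sum(a == b for a, b in zip(acts_a, acts_b))\n",
    "    return same / len(acts_a)\n",
    "\n",
    "\n",
    "# Hypothetical 3-state, 2-action Q tables, for illustration only.\n",
    "q_a = [[0.1, 0.5], [0.3, 0.2], [0.0, 0.0]]\n",
    "q_b = [[0.2, 0.4], [0.1, 0.6], [0.0, 0.0]]\n",
    "print(policy_agreement(q_a, q_b))\n",
    "```\n",
    "\n",
    "Assuming each trainer exposes its Q table as `.table`, as `MCControlTrainer` does, the same function could be called on two trainers' tables. States that were never visited have all-zero rows in both tables and therefore count as agreement, so the score is best read alongside the printed tables."
   ]
  },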
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  (Right)\n",
      "SFFF\n",
      "FHFH\n",
      "FFFH\n",
      "HFF\u001B[41mG\u001B[0m\n",
      "Current observation: 15\n",
      "Current reward: 1.0\n",
      "Whether we are done: True\n",
      "info: {'prob': 1.0}\n"
     ]
    }
   ],
   "source": [
    "# You can do more investigation here if you wish. Leave it blank if you don't.\n",
    "env = gym.make('FrozenLakeNotSlippery-v0')\n",
    "\n",
    "obs = env.reset()\n",
    "\n",
    "while True:\n",
    "    # take agent action\n",
    "    # [TODO] Uncomment next line\n",
    "    obs, reward, done, info = env.step(sarsa_trainer.policy(obs))\n",
    "\n",
    "    # render the environment\n",
    "    env.render()  # [TODO] Uncomment this line\n",
    "\n",
    "    print(\"Current observation: {}\\nCurrent reward: {}\\n\"\n",
    "          \"Whether we are done: {}\\ninfo: {}\".format(\n",
    "        obs, reward, done, info\n",
    "    ))\n",
    "    wait(sleep=0.5)\n",
    "\n",
    "    # [TODO] terminate the loop if done\n",
    "    if done:\n",
    "        env.reset()\n",
    "        break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "------\n",
    "\n",
    "## Conclusion and Discussion\n",
    "\n",
    "It's OK to leave the following cells empty. In the next markdown cell, you can write whatever you like. Like the suggestion on the course, the confusing problems in the assignments, and so on.\n",
    "\n",
    "\n",
    "Following the submission instruction in the assignment to submit your assignment to our staff. Thank you!\n",
    "\n",
    "------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}